Retrieval Augmented Generation

The model answering the prompts may require specific data to answer the queries. For example, if asked a question about forecasts, the model needs the actual forecast data.

This knowledge is provided to the model through Retrieval Augmented Generation (RAG). RAG could be performed with the llama_index framework, or with AzureOpenAI itself given access to a dedicated embeddings API, but since I was working with tabular data in a structured format, langchain_experimental.sql.SQLDatabaseChain was sufficient.

An even simpler approach is to use an LLM to generate a SQL query from the prompt (text-to-SQL), then execute the generated query to fetch the data.

SQLDatabaseChain

Essentially, I converted the CSV file into a SQL database, then used SQLDatabaseChain to run the model against the database, answering questions with the results of appropriate SQL queries.

import sqlite3

import pandas as pd
from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

forecast_df = pd.read_csv("data/forecast.csv")
conn = sqlite3.connect('Forecast.sqlite')
c = conn.cursor()

create_query = """
CREATE TABLE IF NOT EXISTS Forecast(
    date TEXT,
    brand TEXT,
    Actual TEXT,
    Forecast TEXT,
    IBP TEXT,
    "ML error" TEXT,
    "IBP error" TEXT,
    Improvement TEXT
)
"""
c.execute(create_query)
conn.commit()
forecast_df.to_sql('Forecast', conn, if_exists="replace", index=False)  # replaces the table above with one inferred from the DataFrame

template = """
---
<Instructions>
---

If asked to process a question, use the following format:

Question: question here
SQLQuery: SQL query to run to process the question.
Answer: Result of the SQL query. This is the output.

---
##Context
<history of intermediate responses>
---

---
<Examples>
---

Only use the following tables:
{table_info}

Question: {input}
"""

prompt = PromptTemplate(input_variables=['input', 'table_info'], template=template)

forecast_db = SQLDatabase.from_uri('sqlite:///Forecast.sqlite')

db_chain = SQLDatabaseChain.from_llm(llm=llm, db=forecast_db, prompt=prompt, verbose=True)
response = db_chain.run("What is the total improvement for all forecasts?")

db_chain runs an appropriate SQL query against the Forecast table in Forecast.sqlite, computes the total improvement, and responds accordingly.
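The data-loading step can be sanity-checked independently of the chain. A minimal, self-contained sketch using an in-memory database and toy data in place of the real forecast.csv:

```python
import sqlite3

import pandas as pd

# Toy data standing in for data/forecast.csv
df = pd.DataFrame({
    "date": ["2024-01", "2024-02"],
    "brand": ["A", "B"],
    "Improvement": [0.1, 0.2],
})

conn = sqlite3.connect(":memory:")
df.to_sql("Forecast", conn, if_exists="replace", index=False)

# Read back the column names pandas actually created
cols = [row[1] for row in conn.execute("PRAGMA table_info(Forecast)")]
print(cols)

# An aggregate query of the kind the chain would generate
total = conn.execute('SELECT SUM("Improvement") FROM Forecast').fetchone()[0]
print(total)
```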

The agent sends individual tasks from the cleaned prompt to this tool whenever it detects that a task requires retrieval from a specific dataset.
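The routing decision itself can be illustrated without the agent framework. A hypothetical sketch (the function and task fields below are my own illustration, not the actual agent code):

```python
def route_task(task, sql_tool, chat_tool):
    """Send tasks that need dataset retrieval to the SQL tool, others to plain chat."""
    if task.get("needs_rag"):
        return sql_tool(task["question"])
    return chat_tool(task["question"])

# Stand-in tools; in the real pipeline sql_tool would be db_chain.run
sql_tool = lambda q: "SQL: " + q
chat_tool = lambda q: "CHAT: " + q

answer = route_task({"needs_rag": True, "question": "total improvement?"}, sql_tool, chat_tool)
print(answer)  # SQL: total improvement?
```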

Manual

In the manual approach, I first use an LLM to generate a SQL query:

from langchain.prompts import PromptTemplate

def text_to_sql(llm, prompt):
    prompt_template = PromptTemplate(
        template="""
Given the following tables, your job is to write queries given the user's request.
Your response will be fed directly into a function for SQL query execution, so do not respond with anything other than the query.
Do not perform cross joins, since they return a huge amount of data.

---
### Database Schema
This query will run on the following tables with the following columns:
<table description>
---

### Request
{input}
""", input_variables=["input"]
    )

    response = llm.invoke(prompt_template.format(input=prompt)).content
    print("generated query: ", response)
    return response
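One practical caveat: depending on the model, the response may come back wrapped in markdown code fences. A small sanitizing helper (my own addition, not part of the pipeline above) that can be applied before execution:

```python
def strip_code_fences(text):
    """Remove a surrounding markdown code fence, if present, from an LLM response."""
    text = text.strip()
    if text.startswith("```"):
        lines = text.splitlines()[1:]          # drop opening fence (and language tag)
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]                 # drop closing fence
        text = "\n".join(lines).strip()
    return text

print(strip_code_fences("```sql\nSELECT 1;\n```"))  # SELECT 1;
print(strip_code_fences("SELECT 1;"))               # unchanged
```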

This is then executed on the database containing the tables:

import sqlite3

import pandas as pd

inventory_df = pd.read_csv("data/inventory.csv")
sales_df = pd.read_csv("data/delivery_data.csv")
forecast_df = pd.read_csv("data/final.csv")

conn = sqlite3.connect('data/database.sqlite')
c = conn.cursor()

inventory_df.to_sql('Inventory', conn, if_exists="replace", index=False)
sales_df.to_sql('Sales', conn, if_exists="replace", index=False)
forecast_df.to_sql('Forecast', conn, if_exists="replace", index=False)

def process_sql(query):
    # Execute the generated query against the same database the tables were loaded into
    conn = sqlite3.connect('data/database.sqlite')
    cursor = conn.cursor()
    cursor.execute(query)
    results = cursor.fetchall()
    conn.close()
    return results

sql_query = text_to_sql(llm, user_prompt)
results = process_sql(sql_query)
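Putting the two steps together, here is a self-contained end-to-end sketch with the LLM call stubbed out (a hard-coded query and an in-memory toy table stand in for the model and the real data):

```python
import sqlite3

def fake_text_to_sql(question):
    # Stub: a real implementation would call the LLM as shown above
    return "SELECT SUM(qty) FROM Inventory"

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Inventory (item TEXT, qty INTEGER)")
conn.executemany("INSERT INTO Inventory VALUES (?, ?)", [("a", 3), ("b", 4)])
conn.commit()

sql_query = fake_text_to_sql("How many items are in stock?")
results = conn.execute(sql_query).fetchall()
print(results)  # [(7,)]
```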